/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef QUANT_DECODE_OP_H_
#define QUANT_DECODE_OP_H_

#include <algorithm>
#include <functional>
#include <map>
#include <utility>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/typeid.h"

namespace caffe2 {

namespace {

template <class CodebookT, class CodeT>
void Decode(
    const TensorCPU& codebook,
    const TensorCPU& codes,
    /* optional */ const TensorCPU* const decoded_grad,
    TensorCPU* const output,
    bool resizeOnly) {
  CAFFE_ENFORCE(codebook.IsType<CodebookT>());

  auto* cb_ptr = codebook.data<CodebookT>();
  int cb_size = codebook.size();

  CAFFE_ENFORCE(codes.IsType<CodeT>());
  auto* code_ptr = codes.data<CodeT>();

  if (decoded_grad == nullptr) {
    // Forward pass: decode and store codebook values in output.
    output->ResizeLike(codes);
    auto* out_ptr = output->mutable_data<CodebookT>();
    if (resizeOnly) {
      return;
    }

    int sz = output->size();
    for (int i = 0; i < sz; i++) {
      DCHECK_LE(*code_ptr, cb_size);
      *out_ptr++ = cb_ptr[*code_ptr++];
    }
  } else {
    // Backward pass: decode and accumulate gradient w.r.t. codebook values.
    CAFFE_ENFORCE_EQ(codes.size(), decoded_grad->size());
    auto* gradient_ptr = decoded_grad->data<CodebookT>();
    auto* const gradient_end = gradient_ptr + decoded_grad->size();

    CAFFE_ENFORCE_EQ(cb_size, output->size());
    auto* out_ptr = output->mutable_data<CodebookT>();
    while (gradient_ptr < gradient_end) {
      DCHECK_LE(*code_ptr, cb_size);
      out_ptr[*code_ptr++] += *gradient_ptr++;
    }
  }
}

// Builds one entry of the DecodeGeneral dispatch map: maps the
// (codebook type id, codes type id) pair to a lambda that forwards to the
// matching Decode<codebookType, codesType> instantiation.
// NOTE: comments must stay outside the macro body — a // comment inside
// would swallow the line-continuation backslash.
#define REGISTER_DECODER(codebookType, codesType)                      \
  {                                                                    \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()},         \
        [](const TensorCPU& codebook_,                                 \
           const TensorCPU& codes_,                                    \
           const TensorCPU* gradient_,                                 \
           TensorCPU* outDecoded_,                                     \
           bool resizeOnly_) {                                         \
          Decode<codebookType, codesType>(                             \
              codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \
        }                                                              \
  }

inline void DecodeGeneral(
    const TensorCPU& codebook,
    const TensorCPU& codes,
    const TensorCPU* gradient,
    TensorCPU* outDecoded,
    bool resizeOnly) {
  const static std::map<
      std::pair<CaffeTypeId, CaffeTypeId>,
      std::function<void(
          const TensorCPU& codebook,
          const TensorCPU& codes,
          const TensorCPU* gradient,
          TensorCPU* outDecoded,
          bool resizeOnly)>>
      gDecoderMapper = {REGISTER_DECODER(float, uint8_t),
                        REGISTER_DECODER(float, uint16_t),
                        REGISTER_DECODER(float, int32_t)};

  gDecoderMapper.at({codebook.meta().id(), codes.meta().id()})(
      codebook, codes, gradient, outDecoded, resizeOnly);
}

} // namespace

// Decode tensors based on a given codebook.
// The codebook is generated by model_quantize.py.

// Controls how often QuantDecodeOp performs the actual decode.
enum class QuantDecodeRunTy {
  RUN_ALWAYS, // decode the outputs on every RunOnDevice() call
  RUN_ONCE,   // decode on the first run only; later runs just resize outputs
};

// Forward decode operator: input 0 is a float codebook; each remaining
// input is a tensor of codes decoded into the output of the same index.
template <QuantDecodeRunTy QuantDecodeRun>
class QuantDecodeOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  QuantDecodeOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  ~QuantDecodeOp() {}

  bool RunOnDevice() override {
    // Need at least the codebook plus one codes tensor, and exactly one
    // output per codes tensor.
    CAFFE_ENFORCE_GT(InputSize(), 1);
    CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);

    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());

    // In RUN_ONCE mode, every run after the first only resizes the outputs
    // instead of decoding them again.
    const bool skipDecode =
        QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE && hasRun_;

    for (int idx = 0; idx < OutputSize(); ++idx) {
      DecodeGeneral(
          codebook, Input(idx + 1), nullptr, Output(idx), skipDecode);
    }
    hasRun_ = true;
    return true;
  }

 private:
  // Set after the first run; drives the RUN_ONCE short-circuit above.
  bool hasRun_{false};
};

// Gradient of QuantDecodeOp: accumulates the gradients of all decoded
// tensors back into a single gradient w.r.t. the codebook.
class QuantDecodeGradientOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  QuantDecodeGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}
  ~QuantDecodeGradientOp() {}

  bool RunOnDevice() override {
    // Inputs: 1 codebook, then n code tensors, then n matching gradients
    // (so the total count is odd and at least 3).
    CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
    const int num_code_tensors = (InputSize() - 1) / 2;
    CAFFE_ENFORCE_EQ(OutputSize(), 1);

    const auto& codebook = Input(0);
    CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());

    // The single output is the codebook gradient; zero it before
    // accumulating contributions from each codes/gradient pair.
    auto* grad = Output(0);
    grad->ResizeLike(codebook);
    auto* grad_data = grad->mutable_data<float>();
    std::fill(grad_data, grad_data + grad->size(), 0.0f);

    for (int i = 0; i < num_code_tensors; ++i) {
      // Codes tensor i pairs with gradient tensor i + num_code_tensors.
      DecodeGeneral(
          codebook, Input(i + 1), &Input(i + num_code_tensors + 1), grad, false);
    }
    return true;
  }
};

} // namespace caffe2
#endif // QUANT_DECODE_OP_H_
